{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import gluon, autograd, nd\n",
    "import numpy as np\n",
    "import pickle as p\n",
    "import matplotlib.pyplot as plt\n",
    "from time import time\n",
    "%matplotlib inline\n",
    "ctx = mx.gpu()\n",
    "data_dir = '/home/sinyer/python/data/cifar10'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_cifar(route = data_dir+'/cifar-10-batches-py'):\n",
    "    def load_batch(filename):\n",
    "        with open(filename, 'rb')as f:\n",
    "            data_dict = p.load(f, encoding='latin1')\n",
    "            X = data_dict['data']\n",
    "            Y = data_dict['labels']\n",
    "            X = X.reshape(10000, 3, 32,32).astype(\"float\")\n",
    "            Y = np.array(Y)\n",
    "            return X, Y\n",
    "    def load_labels(filename):\n",
    "        with open(filename, 'rb') as f:\n",
    "            label_names = p.load(f, encoding='latin1')\n",
    "            names = label_names['label_names']\n",
    "            return names\n",
    "    label_names = load_labels(route + \"/batches.meta\")\n",
    "    x1, y1 = load_batch(route + \"/data_batch_1\")\n",
    "    x2, y2 = load_batch(route + \"/data_batch_2\")\n",
    "    x3, y3 = load_batch(route + \"/data_batch_3\")\n",
    "    x4, y4 = load_batch(route + \"/data_batch_4\")\n",
    "    x5, y5 = load_batch(route + \"/data_batch_5\")\n",
    "    test_pic, test_label = load_batch(route + \"/test_batch\")\n",
    "    train_pic = np.concatenate((x1, x2, x3, x4, x5))\n",
    "    train_label = np.concatenate((y1, y2, y3, y4, y5))\n",
    "    return train_pic, train_label, test_pic, test_label\n",
    "\n",
    "def accuracy(output, label):\n",
    "    return nd.mean(output.argmax(axis=1)==label).asscalar()\n",
    "\n",
    "def evaluate_accuracy(data_iterator, net, ctx):\n",
    "    acc = 0.\n",
    "    for data, label in data_iterator:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        output = net(data)\n",
    "        acc += accuracy(output, label)\n",
    "    return acc / len(data_iterator)\n",
    "\n",
    "def net(X):\n",
    "    h1_conv = nd.Convolution(data=X, weight=W1, bias=b1, kernel=W1.shape[2:], num_filter=W1.shape[0])\n",
    "    h1_activation = nd.relu(h1_conv)\n",
    "    h1 = nd.Pooling(data=h1_activation, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n",
    "    h2_conv = nd.Convolution(data=h1, weight=W2, bias=b2, kernel=W2.shape[2:], num_filter=W2.shape[0])\n",
    "    h2_activation = nd.relu(h2_conv)\n",
    "    h2 = nd.Pooling(data=h2_activation, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n",
    "    h2 = nd.flatten(h2)\n",
    "    h3_linear = nd.dot(h2, W3) + b3\n",
    "    h3 = nd.relu(h3_linear)\n",
    "    h4_linear = nd.dot(h3, W4) + b4\n",
    "    return h4_linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_pic, train_label, test_pic, test_label = load_cifar()\n",
    "\n",
    "batch_size = 128\n",
    "train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    train_pic.astype('float32')/255, train_label.astype('float32')), batch_size, shuffle=True)\n",
    "test_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    test_pic.astype('float32')/255, test_label.astype('float32')), batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "weight_scale = .02\n",
    "\n",
    "W1 = nd.random_normal(shape=(20,3,5,5), scale=weight_scale, ctx=ctx)\n",
    "b1 = nd.zeros(W1.shape[0], ctx=ctx)\n",
    "W2 = nd.random_normal(shape=(50,20,3,3), scale=weight_scale, ctx=ctx)\n",
    "b2 = nd.zeros(W2.shape[0], ctx=ctx)\n",
    "W3 = nd.random_normal(shape=(1800, 128), scale=weight_scale, ctx=ctx)\n",
    "b3 = nd.zeros(W3.shape[1], ctx=ctx)\n",
    "W4 = nd.random_normal(shape=(W3.shape[1], 10), scale=weight_scale, ctx=ctx)\n",
    "b4 = nd.zeros(W4.shape[1], ctx=ctx)\n",
    "\n",
    "params = [W1, b1, W2, b2, W3, b3, W4, b4]\n",
    "\n",
    "for param in params:\n",
    "    param.attach_grad()\n",
    "\n",
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E 0; L 2.302505; Tr_acc 0.101570; Te_acc 0.106309; T 2.100871\n",
      "E 10; L 1.458303; Tr_acc 0.478197; Te_acc 0.493078; T 1.983325\n",
      "E 20; L 1.099228; Tr_acc 0.614246; Te_acc 0.600376; T 1.798058\n",
      "E 30; L 0.843323; Tr_acc 0.706014; Te_acc 0.648734; T 1.789720\n",
      "E 40; L 0.630558; Tr_acc 0.782449; Te_acc 0.662975; T 1.787690\n",
      "E 50; L 0.426327; Tr_acc 0.852142; Te_acc 0.662876; T 1.792411\n",
      "E 60; L 0.249280; Tr_acc 0.916017; Te_acc 0.657140; T 1.819582\n",
      "E 70; L 0.112527; Tr_acc 0.968147; Te_acc 0.661392; T 1.894498\n",
      "E 80; L 0.019670; Tr_acc 0.998569; Te_acc 0.651800; T 1.882257\n",
      "E 90; L 0.009770; Tr_acc 0.999840; Te_acc 0.658228; T 1.958185\n",
      "Tr_acc 0.999940; Te_acc 0.657041\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzt3Xl8VPW9//HXNwnZyQIkEBIg7Psi\nRBBU6i6Iyq1aq7bWtlbu/VWs1dal1UcXb9tbrdXaXmvFYrVa5WrdUHFFRXFhkX0nhCUrCUsWErLM\nzPf3x3fQiEEGSBjmzPv5eMwjOWdOZj4nZ+Z9vvOdc77HWGsRERFviQl3ASIi0v4U7iIiHqRwFxHx\nIIW7iIgHKdxFRDxI4S4i4kGHDXdjzKPGmEpjzJpD3G+MMX82xhQaY1YZY8a2f5kiInIkQmm5PwZM\n+Yr7pwIDg7cZwEPHXpaIiByLw4a7tfZ9YM9XLDId+Kd1PgEyjDE57VWgiIgcubh2eIxcoLjVdElw\nXvnBCxpjZuBa96SkpIwbMmRIOzy9iJxomn0BGlv8tAQsPn8AX8Di81ta/AH8AUvAWvwBS7SeH5+b\nkUSXlPij+ttPP/10l7U263DLtUe4mzbmtbnNrLWzgFkABQUFdunSpe3w9CJyIvAHLAs2VfL4R9tZ\nsKnqs/lxBrqnxNMtNYGszglkJsfTOTGO1MQ4kjvFER8XQ3xcDAlxMWQkdyIjKZ7UxDhijcEYMAZi\nY8xn060jx7SRPtZ+cX7rRUxbfxBcxhgwmC/87cGjs7R1n8USE7zDmFbzbdv1GQMZyfGkJhxd/Bpj\ntoeyXHuEewnQq9V0HlDWDo8rIhGiqGofP/zXMjZU1JHdOYEbzx7I2UOz6Z6WSNeUeOJidWDe8dYe\n4T4XmGmMmQNMAGqstV/qkhERb5q/fic/nrOCuFjDA1eM4YKROXRSmIfdYcPdGPM0cAbQzRhTAvwS\n6ARgrf0bMA+4ACgEGoDvdVSxInL8VTc088u5a9m9r5m4WENcTAzJ8bGkJsbR1BLguWUljMhN42/f\nHkdeZnK4y5Wgw4a7tfbKw9xvgevbrSIROWHsb/bz/ceWsKa0lhG5afgClmZfgP0tfuqbfDQ0+7m8\nII+7po8gsVNsuMuVVtqjW0ZEIpy1lk0797F0+x4m9utKv6xUfP4ANzy9jOXF1Tx41VguGKkjnCOJ\nwl0kim3eWccrq8p5dXU5hZX7Pps/tncGmcnxzN9QyX9PH65gj0AKd5Eo4vMHWFNWy4KNVcxbXc7G\nnXUYAxP6duGaSSMYn9+F9zZW8tyyEpbtqGbmmQO4emJ+uMuWo6BwF/G4kr0NvLl2J+9tquLTbXuo\nb/ZjDJzcpwu/vng4U0f0IDst8bPlB/fozIzJ/dhZ20T3tIQwVi7HQuEu4kG1jS08s6SYF5aXsras\nFoD+WSlcMjaPCf26MKFvV7I6Hzq4jTH0SE885P1y4lO4i3hI8Z4GHvmgiH9/WkJDs5/RvTL42dQh\nnDe8B327pYS7PDmOFO4iHrCvyceD7xYy+4OtAFw0uiffnZTPyLz0MFcm4aJwF4lglbWNvLyqnL8t\n2EJVXROXjM3l1vOHqEtFFO4ikWZfk4+XV5bx4vJSFm/bg7Xu0MVZV4/jpN6Z4S5PThAKd5EIsbas\nhic/2cHcFaXUN/vpl5XCj84ayEWjcxiQ3Tnc5ckJRuEucoJbtmMvf5m/mXc3VpHYKYaLRvXkygm9\nOalXxiGHsBVRuIucYPbUN7Nk2x6W7djL4q17WL6jmszkTtxy/mC+fUof0pM6hbtEiQAKd5ETRLMv\nwKz3t/Dndwpp9gXoFGsY1jOdn18whG9N6EPKUV7cQaKTXi0iYdbiD7Bk6x5+9fJaNu3cx7SROXz/\ntHyG90zXSIty1BTuImFQWLmPxz7aysriGjburKPZFyA3I4nZ1xRw9tDu4S5PPEDhLnIcFVXt48/z\nNzN3ZRkJcbGM65PJ9yblM6xnGucM7a6uF2k3eiWJHAdVdU3c//Ym5izeQUJcLNdN7seM0/vRNVUD\nc0nHULiLdBBrLdt3NzB3ZRkPL9hCky/A1af0YeZZA79y0C6R9qBwF2lH1loWFu7i35+W8EnRbnbW\nNgFw/vDu3DZlCP2yUsNcoUQLhbtIO/D5A7y6upyHFxSxrryWrinxTBrQjQl9uzCxf1f6K9TlOFO4\nixyj5Tv2cueLa1hbVkv/rBTuuXQU00/qSUKcDmOU8FG4ixylvfXN/OHNjTy9eAfZnRP4y5UnMW1k\nDjExGhJAwk/hLnKEdu9r4pEPtvLEx9vY3+Lne5P6ctO5A+mcqGEB5MShcBcJUVVdE498UMQTH2+n\n0efnwlE9ueGsAQzqrhEZ5cSjcBc5jLLq/Ty6cCtPLtpOsy/A9DG5XH/mAAZk60tSOXEp3EXaUFnb\nyNyVZcxbXc6yHdXEGPiPk3KZeeYAHc4oEUHhLtJKRU0jD71XyNOLi2n2BxiWk8ZPzxvExaNz6d01\nOdzliYRM4S6C+5L0f98t5F+LdhAIWC4bl8d1k/vp+HSJWAp3iWr1TT5mL9zKwwu20OgLcNnYPGae\nNYBeXdRKl8imcJeoU9/k472NVbyxtoJ3N1RS1+RjyvAe/PT8wfqSVDxD4S5R5Z0NO/nJMyvZ29BC\nl5R4po7swRXjezO2d2a4SxNpVwp3iQot/gD3vrGRh98vYlhOGn/91jhOzs8kLjYm3KWJdAiFu3ha\nbWMLr6+p4ImPt7O6tIZvn9KbO6cN0+XrxPMU7uJJ23fXc88bG3lr3U6afQF6dUniL1eexEWje4a7\nNJHjQuEuntLiD/D3D7byp7c30Sk2hqvG92b6mJ6M6ZWBMRrQS6JHSOFujJkCPADEAn+31v7+oPt7\nA48DGcFlbrfWzmvnWkUOaV+Tj3mrynn0w61sqKjj/OHd+fXFI+iRnhju0kTC4rDhboyJBR4EzgVK\ngCXGmLnW2nWtFrsTeMZa+5AxZhgwD8jvgHpFvqCsej9/ensTr6wqp6HZT7+sFP727XFMGdEj3KWJ\nhFUoLffxQKG1tgjAGDMHmA60DncLpAV/TwfK2rNIkYNZa3l6cTG/m7ceXyDA9NG5XH5yHmN7Z6r7\nRYTQwj0XKG41XQJMOGiZXwFvGmNuAFKAc9p6IGPMDGAGQO/evY+0VolyNftbKKysY/POfcxdWcZH\nW3YzsV9X7r50lMZ9ETlIKOHeVjPIHjR9JfCYtfaPxpiJwBPGmBHW2sAX/sjaWcAsgIKCgoMfQ6RN\ne+qb+c2r63h+Weln89IS4/jNf4zgqvG9deUjkTaEEu4lQK9W03l8udvlWmAKgLX2Y2NMItANqGyP\nIiU6WWt5blkpv311HXWNPn5wWl8m9u/KwOzO5GUmKdRFvkIo4b4EGGiM6QuUAlcAVx20zA7gbOAx\nY8xQIBGoas9CJbqsK6vlV3PXsnjbHsb1yeR/LhmpKx6JHIHDhru11meMmQm8gTvM8VFr7VpjzF3A\nUmvtXOAnwCPGmJtwXTbftdaq20WOWGVdI3+ev5mnFu0gIzme/7lkJN8s6KVWusgRCuk49+Ax6/MO\nmveLVr+vA05t39IkWvgDlvc3VzFn8Q7mr6/EAt+ZmM9N5wwiPVkXnRY5GjpDVcJqUdFufjl3LRsq\n6uiaEs/3T+vLFSf30qXsRI6Rwl3CYmdtI/8zbz0vrigjNyOJB64Yw9QROcTHaZRGkfagcJfjqmZ/\nCw8v2MKjH24lEICZZw7g+jMHkBSvURpF2pPCXY6L3fuaeHrxDh75YCs1+1u4eHRPfnLeIPp0TQl3\naSKepHCXDrWmtIbZC7fy6qpymv0BzhicxS3nD2Z4z/RwlybiaQp36RD7m/3c99ZGZi/cSnJ8HFeO\n78XVE/swIFvHqoscDwp3aXefFO3m9udWsW13A9+a0Jvbpg4hLVGHNIocTwp3aTcriqu5/61NLNhU\nRa8uSTz1gwlMGtAt3GWJRCWFuxwTf8Dy3sZKnvhkO+9trCIzuRM/mzqEqyf2ITleLy+RcNG7T45K\nY4ufR94v4qnFOyivaaRbagK3nD+Yayblk5qgl5VIuOldKEdsZXE1Nz+zgi1V9UwelMUvLxrG2UO7\n0ylWJyCJnCgU7hKyJp+fB98p5MH3tpDdOYEnrh3P6QOzwl2WiLRB4S4hWbJtDz97fjWFlfu45KRc\nfnnxcNKTdASMyIlK4S5faU99M/e+uZGnFu0gNyOJf3z3ZM4ckh3uskTkMBTu0qbGFj+Pf7SN/323\nkPomdxWkm84dRIq+LBWJCHqnyhf4A5aXVpTyxzc3UVq9n7OGZHP71CG6CpJIhFG4C+CuV/repiru\nfm0DGyrqGJGbxh8uG6WTkEQilMJdqG5o5o4X1vDq6nJ6d0nmz1eexIUjc3RpO5EIpnCPcgs37+In\nz65gT30zt5w/mOtO76cLZoh4gMI9SlXUNHLfWxt5ZmkJ/bNSmH3NyYzI1TC8Il6hcI8y+5p8zFqw\nhVkfFBEIwHWn9+XmcwfrSkgiHqNwjxKNLX6e/GQ7f31vC3vqm7lwVA63nj+E3l2Tw11a5AgEYNcm\naKyG7sMhQUcQyYlL4e5x1lrmrizj969toLymkdMGdOOn5w9mTK+McJd2eAE/1FVAwOduvkbYtxPq\ndkJjDaTlQEZvSO/lgjY2HoyBxlqoKYHq7bBzDVSshsoNkJkPfSe7W7eB0CnJPY+vGSpWQckS95xd\n+rlbwAeV66FqPZStgNKl7nkP6DoAsoZAYrp7/uSukD0Uuo+AjD4Qo+8uJHwU7h62e18Td764htfW\nVDAqL50/fmP08T+0ce92F5B1ZS6oa0uhptT9TMmGk78PQy+G2E7Q0gjFn8C2hVC8GEqXQXNd6M9l\nYl3A+/Z/cX5mXxfCuzfD5jc+nx+fCsld3M7C3/TVj5s1GIZ/HfLGuxCvWA3lK2D3FmiqC95aBX9s\nfHBnEwMxcZCaDZ1z3C09F9J6QnpvyD8NOiV++Tmrd8D6l2HzW5AzCibfCgmpof8vAKqLoXwl9Pva\n558yrIUdH8PuQhh5+Refu/RTWHg/nHoT5I0L7Tl2FUJ9ldt+MXGQ2QeSMj+/39cM2xe6/0/3EW5b\nHG6n529xO2VfE/ibwQYgLgnikyE2wb0mGmvcrWW/u/lbIKWb29mn9XR/t3+vuzXVQXO9u8WnuG2Q\n1tP9T2Li3C0+9cvbobkBmmo/nzYxEJfoGgWxndz/0lp33wm4Izf2QHHHWUFBgV26dGlYntvrrLW8\nsbaCO19cQ+1+HzedO4gZk/sRe7wObWysgbUvwsqnXZB8xkBq98/DrWI17N3m3mzZQ2H7xy6YTazr\n9ug13v2MTXBvwLh49/ep3SEhze0wqne4nUXzPmhpcIGQmg3peS48swa5lvUBNaWw/UOoKYb6Xe6W\nmu2eK288xCXA3q2wZ6v7FJA11LXQ4+IPv97N9e4Tws7VsKfIfQqwAVfTvp3BnVsZ7Ktw88Ht4E75\nLyi41oXk+rmwbq7bcQB06Q97tkBaLky9G4Zc6Ob7mtxjGONC58CnlgM2vw3Pfd9ti7hEGDwVug6E\n1c+69QMXtpf9w/2PVj0Lc2e6T0ex8TDtjzD2O4de18r18O5v3Q7oCwxkD4M+E6Fht6uj9Q66UzLk\njIF+Z7hb7lgXlOB2BCufhg/uddv1eEvKhM493f+zthT27wn9b2MT3M6nU7Jbn9h49zpurnc7/eZ6\nd19CZ/fa/dqtMOKSoyrTGPOptbbgsMsp3L2leE8Dv5y7lnc2VDIsJ437vjmaIT3Sju1BAwEItLhu\nCnCtn7ZUrIbFj7gAaWlwYTLmSsif7LpQUrt//kY+8LiFb8Gih13o9Z0M/c+C/FO93Z/t97mA37kO\nFv0NtsyHmE7ufwyQWwDDLnZB3rU/7FgEr97sWrNxSS6AOeh9m9odhl8CIy+DbR/A2792O8az7oTC\n+bD2BWjYBfmnw5hvuU8BL9/oWr2Dp8Ka56DPqXDhn+D122DLOzD2Gjj1RtdFZYwL36L3YNUcWPO8\na+1O/CH0PsWtk7/Jhf72j1wXV6dkGDwFBk9zO9Cda906FC9y3VxYt9M+0JKuLYeaHdBzLJzyQ0jp\n6v4vJsbt9JsbXIs8PtXtsBPTXCs6LsmFaX2l646rLXVhm5QJSRkuTONT3K2pDurK3eutuR6s39Xe\nWOPm15W7nXJ6rtuhJmV+vtMM+N3/vqXR1WFi3M0GPq+vZb+778D7JT41+PzJ7u+a6tyngXHfhQFn\nH9XLR+EeZay1zF64lXvf3EiMMdx87iC+OymfuGMdY335kzDvFhfWB6TluRZXzmg3f+9290VjxSrX\nShx5GYz7HuSO+2JrUtpWvgpWPOVCdMg0FywH8/tg2eOu1R2X5LoQTCxgXeiULYfNb7pgAdeFNP3B\nz3fEBwIspevnj1lbBs/PcDuDgu/DlLvdJ5SAH975DSy8zy2XkgU9Rrpum8YaSEiHcdfAaTe5bq22\nBA58qjjE9m/YA1vfd6+Z2jIXyjFxMPF6GHCOXjdfQeEeRWobW/jpMyt5c91OzhnanbumD6dnRtKx\nPWggAO/8t3uD55/uPkLHxAW/ZFzn+sP3bnUBk57n+joHngcnffvQb3jpWPurYcOr7tPRyG+EFpAB\nP+zaDNlDvnzfrs3B7z8WuR1Qzii30+h3ZmjdVNIhFO5RYn15Lf/vyU8p2bufn18wlO+dmo8JtdXj\na3JfOHXu8cX5TXXw0kxY96L7+HjBvV/sTmm9XFwSxOp7eZHjJdRw17sygr21bic/eno5aUlxzJlx\nCgX5h2gxv3Kz61MddL772N85B1b9n5u3fy+MuNT1zWb2dfPevNN9+Xfeb2DizEO3AL3cLy4S4RTu\nEehA//pv561nVG46j3yngOy0Ng6nA1j1DCyd7Y4E2fymC3VwfeMHgn7JbFj3kjsyZOdqdzTDN5+E\nvMM2DkTkBKVwjzD7m/3c9co6nl68g6kjenDf5WMOPXTArkJ45SboPRGuecXN2/GRa5UPOv/zQwQn\n3QAL7nZfcE37o/syNEbDEYhEMoV7BFldUsOP/285W6rq+eEZ/fnpeYMPPSxvSyP8+7vuELFLZ3/e\nL9538peX7dwDLry/w+oWkeNP4R4B/AHLQ+8V8qe3N9MtNYF//WACpw7o5o4nnnerO2524Lkw4Fx3\nEs72D10XTMVquOqZtg+tExFPCyncjTFTgAeAWODv1trft7HM5cCvcGdXrLTWXtWOdUatyrpGfjxn\nBR9t2c1Fo3vym+kjSDf17gSUTx9zhyBmD3fHSS/5u/sjE+uOQb/wftf9IiJR57DhboyJBR4EzgVK\ngCXGmLnW2nWtlhkI/Aw41Vq71xiT3VEFR5MPC3dx45wV7Gtq4Z5LR/GNUZmYpQ/Bhw+4U6MnzoQz\nf+5OVPE1uVP9Az7oNUFHsohEuVBa7uOBQmttEYAxZg4wHVjXapnrgAettXsBrLWV7V1oNLHWMuv9\nIu5+fQP9slJ56gfjGbR9Dvz5Hjf+SP+z4OxfQs8xn/9RXII70UhEhNDCPRcobjVdAkw4aJlBAMaY\nD3FdN7+y1r5+8AMZY2YAMwB69+59NPV6XmOLn9ufW8WLK8q4YGQP7r1kKMmv3eTG8+g7Gc78F/Q+\n+N8vIvJFoYR7W4djHHxaaxwwEDgDyAM+MMaMsNZWf+GPrJ0FzAJ3huoRV+txO2sbue6fS1lVUsNP\nzxvE9ad0w/zf5W7sjzPvgMm3aMwNEQlJKOFeAvRqNZ0HlLWxzCfW2hZgqzFmIy7sl7RLlVFgfXkt\nM/7xMec3vsaDw/bTq6IaHlrlumG+/jCMviLcJYpIBAkl3JcAA40xfYFS4Arg4CNhXgSuBB4zxnTD\nddMUtWehXrZgUxU//ddHPBBzP5NilkN5FzfcaM5omDTTXdBBROQIHDbcrbU+Y8xM4A1cf/qj1tq1\nxpi7gKXW2rnB+84zxqwD/MAt1trdHVm4Vzy7tJi7n/+QJ5P+yOBAoRtPu+B74S5LRCKcRoUMo0fe\nL2LOa2/zZMoD9GAX5rJH3XgvIiKHoFEhT2DWWu55fQNVC//BvMTHiY9PxXzzRXdpMhGRdqBwP84a\nW/zc8exSJq2/i9s6LcT2Pg1z6d/dZehERNqJwv042rWviRn/XMoZZbO4NG4h9mu3Yb52m0ZgFJF2\np3A/Toqq9nH17MWk1Rcxs9OrMPKbmDN/Hu6yRMSjFO7HQXnNfq6evZjGZh9v5D1LzN5kd5UjEZEO\nonDvSDWlNCx/ljuWdKN2fxavnVlK6nufuNEaUzW2moh0HIV7R9lfjf+Jr5O8ayOPAvVdhpGyqALy\nToax3w13dSLicTHhLsCTfM3UP3kVgV1buL7lx2wYcwcpyUnga4Rp90GM/u0i0rHUcm9v1rLtsR+Q\nX/ohv4i5gW99/waG9O8G3AqBgIJdRI4LhXt7aW6A9XMpnv8w+bXLeLbz1cyccSfZnRM/X0bBLiLH\nicL9WFkLH/8vvHc3NNcRCGTzcs4PuWTGb4mNVZiLSHgo3I+FvwVevRmW/ZO9eWdx/bbToPckHr92\ngoJdRMJK4X60GmvgmWug6F2qx/2Is5edSkZmIi98u4BOCnYRCTOF+9Hwt8DTV0LxIirP/CPTP+qL\nNX4eveZk0pM7hbs6ERGF+1GZ/2vY/iGlZz7A9A9ygQBPXXcK+d1Swl2ZiAigcD9y61+Gj/7C7qFX\nM21BTxLjYnjyBxMYkJ0a7spERD6jcD8Su7fAiz/E32MMlxZdREp8HHNmnEKvLsnhrkxE5Av0zV+o\ntn8Mj18EJoY/pP2cHbU+/nLVSQp2ETkhKdwPJ+CHBX+Axy6A2HgWT/4Hf1vl47++1p+xvTPDXZ2I\nSJsU7ofz/HXw7m9gxKXsvXo+178bYEiPztx4zsBwVyYickgK96+y7UNY8xxMvoXGi/7GTS9tobqh\nmfsuH0NCnK6eJCInLn2heijWwtu/gs45NEy4kf984lM+2LyL3319JMN6poW7OhGRr6RwP5SN86Bk\nMY1T7uOaJ1bz6fa9/OGyUXyjoFe4KxMROSyFe1sCfph/F7brAL63YjDLd1TzlyvHMm1UTrgrExEJ\nifrc27JyDlRt4N3c/+TjbTX87pKRCnYRiShquR+saR+8+1uaskdz/fJenDG4K98YlxfuqkREjoha\n7gdb8HuoLeW3gWuIjYnhd18fiTEm3FWJiBwRhXtrO9fCx3+lMO9S/lnSgzumDaVnRlK4qxIROWIK\n9wMCAXjlZmxiBteWTmNiv65ccbKOjBGRyKRwP2DFv6D4E97Pv4Ht+xO5dcpgdceISMTSF6oADXvg\nrV8Q6HUKtxUOZ2K/NE7SuDEiEsHUcgd493fQWM1b+bdQUdfC/zujf7grEhE5Jmq5V6yGpbMJFFzL\n3cvjGN4zgdMHdgt3VSIixyS6W+7Wwmu3QWIG7/S4lqJd9fzwjAHqaxeRiBfdLfc1z8H2D+HCP/GX\nj3eT3zWZKSN6hLsqEZFjFlLL3RgzxRiz0RhTaIy5/SuWu8wYY40xBe1XYgfxt8Bbv4Cc0RTmXcLK\nkhq+MzGf2Bi12kUk8h023I0xscCDwFRgGHClMWZYG8t1Bn4ELGrvIjvEtoVQWwqTb2He2koALhip\n8WNExBtCabmPBwqttUXW2mZgDjC9jeX+G7gHaGzH+jrOupcgPhUGnMO81eUU9MmkR3piuKsSEWkX\noYR7LlDcarokOO8zxpiTgF7W2le+6oGMMTOMMUuNMUurqqqOuNh2E/DD+pdh0PlsqfazoaJOrXYR\n8ZRQwr2tTmj72Z3GxAD3Az853ANZa2dZawustQVZWVmhV9netn8EDbtg2HTmrSoHYOpIfZEqIt4R\nSriXAK0HWckDylpNdwZGAO8ZY7YBpwBzT+gvVde9BJ2SYcC5vLq6nHF9MslJ1wBhIuIdoYT7EmCg\nMaavMSYeuAKYe+BOa22NtbabtTbfWpsPfAJcbK1d2iEVH6tAANbPhYHnUlQTUJeMiHjSYcPdWusD\nZgJvAOuBZ6y1a40xdxljLu7oAttd8SLYtxOGXsy81a5L5gJ1yYiIx4R0EpO1dh4w76B5vzjEsmcc\ne1kdaN1LEJsAg85n3jsrGNs7Q10yIuI50TX8wIEumQHnUFgD68prmTaqZ7irEhFpd9EV7uXL3YlL\nQy/i+WWlxMYYLhqt/nYR8Z7oCveNr4OJITDgPF5YXsrkgd3I7qwTl0TEe6Ir3De9Br0m8EmFpbym\nkUvG5oW7IhGRDhE94V5T4sZuHzSF55aV0jkhjnOHdQ93VSIiHSJ6wn3T6wDs73cur60pZ9qoHBI7\nxYa5KBGRjhFF4f4GZPbl9Yo0Gpr96pIREU+LjnBvroeiBTBoCs8vL6NXlyQK+ugC2CLiXdER7kXv\ngb+Jvb3OYmHhLr4+JpcYXZRDRDwsOsJ942uQkMbcvflYCxePyT3834iIRDDvh3sg4PrbB5zNS6ur\nGJqTxoDs1HBXJSLSobwf7iWLob6SPXlnsWxHNReO0hmpIuJ93g/3lXMgLokXG8YAcJHGkhGRKODt\ncPc1wdoXYOiFvLCultF56fTumhzuqkREOpy3w33TG9BYTUX+11ldWsOFarWLSJTwdrivnAOp3Xmu\nuh8A09TfLiJRwrvh3rAHNr8JI7/By6urKOiTSc8MXZRDRKKDd8N9zXMQaKG0z3Q2VNSp1S4iUcW7\n4b5yDmQPZ/F+188+sX/XMBckInL8eDPc9xRB6VIY/U1WFteQ1CmWAVk6cUlEooc3w71yvfuZfxqr\nSqoZkZtGXKw3V1VEpC3eTLzaMgBaUnuytqyWUXkZYS5IROT48mi4l0JMHJv2JdLkCzAqLz3cFYmI\nHFceDfcy6NyTVaV1AIxWy11Eoox3wz0th1Ul1aQndaKPhhwQkSjj4XDvycriGkblpWOMLswhItHF\ne+FuLdSW4UvNYePOOvW3i0hU8l64798Lvv2U2y74A1ZHyohIVPJeuAcPgyxsTAP0ZaqIRCfPhvvK\nmhSyOyfQIz0xzAWJiBx/Hgz3UgA+3pWgLhkRiVreC/e6cqyJ4dPd8YzWl6kiEqW8F+61pbQkZuEj\njhEKdxGJUh4M9zJq47MBNBKkiEStkMLdGDPFGLPRGFNojLm9jftvNsasM8asMsbMN8b0af9SQ1Rb\nxu6YLsTHxejKSyIStQ4b7sZnCHTXAAAIQklEQVSYWOBBYCowDLjSGDPsoMWWAwXW2lHAv4F72rvQ\nkNWWUeLvQp8uycTG6MxUEYlOobTcxwOF1toia20zMAeY3noBa+271tqG4OQnQF77lhmipjpoqqWo\nKY2+3VLCUoKIyIkglHDPBYpbTZcE5x3KtcBrbd1hjJlhjFlqjFlaVVUVepWhqi0HYH19Z4W7iES1\nUMK9rb4N2+aCxnwbKAD+0Nb91tpZ1toCa21BVlZW6FWGKniMe4k/U+EuIlEtLoRlSoBerabzgLKD\nFzLGnAPcAXzNWtvUPuUdoeDZqeV0IV/hLiJRLJSW+xJgoDGmrzEmHrgCmNt6AWPMScDDwMXW2sr2\nLzNEwXCvtJn0U7iLSBQ7bLhba33ATOANYD3wjLV2rTHmLmPMxcHF/gCkAs8aY1YYY+Ye4uE6Vm0p\n9XEZxMUnkdU5ISwliIicCELplsFaOw+Yd9C8X7T6/Zx2rit01sKBi3HUllEV0438bim6QIeIRLXI\nPkO18G24dxBUbXTTtWWU+TPU3y4iUS+yw33Tm1BfCc/9AHzN2LoytjZnqL9dRKJeZId7yWJIyYKK\nVfD2LzENuykLdCG/q8JdRKJbSH3uJ6SW/VCxGib9COqr4JO/AlBhu3B2lsJdRKJb5IZ72QoI+CDv\nZOg7GbYthL1bKacLfdVyF5EoF7ndMiVL3M+8kyEhFS6dzdaU0RTHDyAzJT68tYmIhFnkttxLFkNm\nPqQGhzHIG8edmffQNdkf1rJERE4EkdlytxaKl7hWeytbq+o1poyICJEa7jUlsK8C8sZ/NquxxU9Z\nTaPCXUSESA33z/rbCwDw+QMs2bYHQCcwiYgQqX3uJUshLpFlLXncfv8Ciqrq8QXcKMRDenQOc3Ei\nIuEXoeG+mEDOGG5+di3NvgAzJvejb7cUhuakMai7wl1EJPLC3dcE5StZnH0523Y38NR1E5jUv1u4\nqxIROaFEXp97+SrwN/P4jiyuOLmXgl1EpA0R13L3Fy8mFtiWNIw5U4eGuxwRkRNSxIX7C3v6srHl\nKm687HTSkzuFuxwRkRNSxIX7yRO/RnnSQKaMyAl3KSIiJ6yI63Pv0zWFG84eGO4yREROaBEX7iIi\ncngKdxERD1K4i4h4kMJdRMSDFO4iIh6kcBcR8SCFu4iIByncRUQ8SOEuIuJBCncREQ9SuIuIeJDC\nXUTEgxTuIiIepHAXEfEghbuIiAcp3EVEPEjhLiLiQSGFuzFmijFmozGm0Bhzexv3Jxhj/i94/yJj\nTH57FyoiIqE7bLgbY2KBB4GpwDDgSmPMsIMWuxbYa60dANwP3N3ehYqISOhCabmPBwqttUXW2mZg\nDjD9oGWmA48Hf/83cLYxxrRfmSIiciTiQlgmFyhuNV0CTDjUMtZanzGmBugK7Gq9kDFmBjAjOLnP\nGLPxaIoGuh382FEiGtc7GtcZonO9o3Gd4cjXu08oC4US7m21wO1RLIO1dhYwK4Tn/OqCjFlqrS04\n1seJNNG43tG4zhCd6x2N6wwdt96hdMuUAL1aTecBZYdaxhgTB6QDe9qjQBEROXKhhPsSYKAxpq8x\nJh64Aph70DJzgWuCv18GvGOt/VLLXUREjo/DdssE+9BnAm8AscCj1tq1xpi7gKXW2rnAbOAJY0wh\nrsV+RUcWTTt07USoaFzvaFxniM71jsZ1hg5ab6MGtoiI9+gMVRERD1K4i4h4UMSF++GGQvACY0wv\nY8y7xpj1xpi1xpgbg/O7GGPeMsZsDv7MDHet7c0YE2uMWW6MeSU43Tc4pMXm4BAX8eGusb0ZYzKM\nMf82xmwIbvOJUbKtbwq+vtcYY542xiR6bXsbYx41xlQaY9a0mtfmtjXOn4PZtsoYM/ZYnjuiwj3E\noRC8wAf8xFo7FDgFuD64nrcD8621A4H5wWmvuRFY32r6buD+4DrvxQ114TUPAK9ba4cAo3Hr7+lt\nbYzJBX4EFFhrR+AO1rgC723vx4ApB8071LadCgwM3mYADx3LE0dUuBPaUAgRz1pbbq1dFvy9Dvdm\nz+WLwzw8DvxHeCrsGMaYPGAa8PfgtAHOwg1pAd5c5zRgMu6IM6y1zdbaajy+rYPigKTguTHJQDke\n297W2vf58jk/h9q204F/WucTIMMYk3O0zx1p4d7WUAi5YarluAiOsHkSsAjobq0tB7cDALLDV1mH\n+BNwKxAITncFqq21vuC0F7d3P6AK+EewO+rvxpgUPL6trbWlwL3ADlyo1wCf4v3tDYfetu2ab5EW\n7iENc+AVxphU4Dngx9ba2nDX05GMMRcCldbaT1vPbmNRr23vOGAs8JC19iSgHo91wbQl2M88HegL\n9ARScN0SB/Pa9v4q7fp6j7RwD2UoBE8wxnTCBfu/rLXPB2fvPPAxLfizMlz1dYBTgYuNMdtw3W1n\n4VryGcGP7eDN7V0ClFhrFwWn/40Ley9va4BzgK3W2iprbQvwPDAJ729vOPS2bdd8i7RwD2UohIgX\n7GueDay31t7X6q7WwzxcA7x0vGvrKNban1lr86y1+bjt+o619lvAu7ghLcBj6wxgra0Aio0xg4Oz\nzgbW4eFtHbQDOMUYkxx8vR9Yb09v76BDbdu5wHeCR82cAtQc6L45KtbaiLoBFwCbgC3AHeGup4PW\n8TTcx7FVwIrg7QJcH/R8YHPwZ5dw19pB638G8Erw937AYqAQeBZICHd9HbC+Y4Clwe39IpAZDdsa\n+DWwAVgDPAEkeG17A0/jvlNowbXMrz3UtsV1yzwYzLbVuCOJjvq5NfyAiIgHRVq3jIiIhEDhLiLi\nQQp3EREPUriLiHiQwl1ExIMU7iIiHqRwFxHxoP8PcoQthGMhQcsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f8a1f500320>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lr = 2e-4\n",
    "epochs = 100\n",
    "\n",
    "a, b = [], []\n",
    "for epoch in range(epochs):\n",
    "    if epoch > 80:\n",
    "        lr = 5e-5\n",
    "    train_loss = 0.\n",
    "    train_acc = 0.\n",
    "    start = time()\n",
    "    for data, label in train_data:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            l = loss(output, label)\n",
    "        l.backward()\n",
    "        for param in params:\n",
    "            param[:] = param - lr * param.grad\n",
    "        train_loss += nd.mean(l).asscalar()\n",
    "        train_acc += accuracy(output, label)\n",
    "    test_acc = evaluate_accuracy(test_data, net, ctx)\n",
    "\n",
    "    if epoch%10 == 0:\n",
    "        print(\"E %d; L %f; Tr_acc %f; Te_acc %f; T %f\" % (\n",
    "            epoch, train_loss / batch, train_acc / batch, test_acc, time() - start))\n",
    "    a.append(train_acc/batch)\n",
    "    b.append(test_acc)\n",
    "    \n",
    "print(\"Tr_acc %f; Te_acc %f\" % (train_acc / batch, test_acc))\n",
    "plt.plot(np.arange(epochs), a, np.arange(epochs), b)\n",
    "plt.ylim(0,1)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
